Школа глубокого обучения ФПМИ МФТИ
Домашнее задание. Сегментация изображений
В этом задании вам предстоит решить задачу сегментации медицинских снимков. Домашнее задание можно разделить на следующие части:
- Построй свой первый бейзлайн! [6]
- BCE Loss [2]
- SegNet [2]
- Train [1]
- Test [1]
- Мир других лоссов! [2]
- Dice Loss [1]
- Focal Loss [1]
- BONUS: лосс из статьи [5]
- Новая модель! [2]
- UNet [2]
Максимальный балл: 10 баллов.
Также для студентов желающих еще более углубиться в задачу предлагается решить бонусное задание, которое даст дополнительные 5 баллов. BONUS задание необязательное.
Шаг 1. Загрузка и подготовка данных¶
- Для начала мы скачаем датасет: ADDI project.
- Разархивируем .rar файл.
- Обратите внимание, что папка
PH2 Dataset imagesдолжна лежать там же где и ipynb notebook.
Это фотографии двух типов поражений кожи: меланома и родинки. В данном задании мы не будем заниматься их классификацией, а будем сегментировать их.
!gdown 1T_RPkPP0jeWwK8L1UrmBw8V30eD7v6Ql
Downloading... From (original): https://drive.google.com/uc?id=1T_RPkPP0jeWwK8L1UrmBw8V30eD7v6Ql From (redirected): https://drive.google.com/uc?id=1T_RPkPP0jeWwK8L1UrmBw8V30eD7v6Ql&confirm=t&uuid=2085d32f-a224-4c84-b7a5-6dfb511a68d1 To: /content/PH2Dataset.rar 100% 162M/162M [00:01<00:00, 125MB/s]
get_ipython().system_raw("unrar x PH2Dataset.rar")
Стуктура датасета у нас следующая:
IMD_002/
IMD002_Dermoscopic_Image/
IMD002.bmp
IMD002_lesion/
IMD002_lesion.bmp
IMD002_roi/
...
IMD_003/
...
...
Здесь X.bmp — изображение, которое нужно сегментировать, X_lesion.bmp — результат сегментации.
Для загрузки датасета можно использовать skimage: skimage.io.imread()
images = []
lesions = []
from skimage.io import imread
import os
root = 'PH2Dataset'
for root, dirs, files in os.walk(os.path.join(root, 'PH2 Dataset images')):
if root.endswith('_Dermoscopic_Image'):
images.append(imread(os.path.join(root, files[0])))
if root.endswith('_lesion'):
lesions.append(imread(os.path.join(root, files[0])))
Изображения имеют разные размеры. Давайте изменим их размер на $256\times256 $ пикселей. Для изменения размера изображений можно использовать skimage.transform.resize().
Эта функция также автоматически нормализует изображения в диапазоне $[0,1]$.
from skimage.transform import resize
size = (256, 256)
X = [resize(x, size, mode='constant', anti_aliasing=True,) for x in images]
Y = [resize(y, size, mode='constant', anti_aliasing=False) > 0.5 for y in lesions]
import numpy as np
X = np.array(X, np.float32)
Y = np.array(Y, np.float32)
print(f'Loaded {len(X)} images')
Loaded 200 images
Чтобы убедиться, что все корректно, мы нарисуем несколько изображений
import matplotlib.pyplot as plt
from IPython.display import clear_output
plt.figure(figsize=(18, 6))
for i in range(6):
plt.subplot(2, 6, i+1)
plt.axis("off")
plt.imshow(X[i])
plt.subplot(2, 6, i+7)
plt.axis("off")
plt.imshow(Y[i])
plt.show();
Разделим наши 200 картинок на 100/50/50 для обучения, валидации и теста соответственно
ix = np.random.choice(len(X), len(X), False)
tr, val, ts = np.split(ix, [100, 150])
print(len(tr), len(val), len(ts))
100 50 50
PyTorch DataLoader¶
from torch.utils.data import DataLoader
batch_size = 25
train_dataloader = DataLoader(list(zip(np.rollaxis(X[tr], 3, 1), Y[tr, np.newaxis])),
batch_size=batch_size, shuffle=True)
valid_dataloader = DataLoader(list(zip(np.rollaxis(X[val], 3, 1), Y[val, np.newaxis])),
batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(list(zip(np.rollaxis(X[ts], 3, 1), Y[ts, np.newaxis])),
batch_size=batch_size, shuffle=False)
loaders = {'train':train_dataloader, 'val': valid_dataloader}
import torch
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(DEVICE)
cuda
torch.manual_seed(42)
torch.cuda.manual_seed(42)
torch.cuda.manual_seed_all(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
Шаг 2. Метрика качества модели¶
IoU (intersection over union)¶
В данном разделе предлагается использовать следующую метрику для оценки качества:
$I o U=\frac{\text {target } \cap \text { prediction }}{\text {target } \cup{prediction }}$
Пересечение (A ∩ B) состоит из пикселей, найденных как в маске предсказания, так и в основной маске истины, тогда как объединение (A ∪ B) просто состоит из всех пикселей, найденных либо в маске предсказания, либо в целевой маске.
Что будет являться пересением и объединением в задаче сегментации?
Давайте разберем следующий пример:
!pip install -q torchmetrics
from torchmetrics import JaccardIndex
iou_score = JaccardIndex(threshold=0.5, task="binary", average='none').to(DEVICE)
Задания: Построй свой первый бейзлайн!¶
Итак, загрузка файлов, код датасета и даталоадера написана за вас. Метрика IoU написана за вас! Вам остается написать лосс, модель и функции обучения и теста модели.
- Построй свой первый бейзлайн! [6]
- BCE Loss [2]
- SegNet [2]
- Train [1]
- Test [1]
Шаг 3. Loss функция - BCE [2 балла]¶
Популярным лоссом для бинарной сегментации является бинарная кросс-энтропия, которая задается следующим образом:
$$\mathcal L_{BCE}(y, \hat y) = -\sum_i \left[y_i\log\sigma(\hat y_i) + (1-y_i)\log(1-\sigma(\hat y_i))\right] \space [1]$$
где $y$ это таргет желаемого результата и $\hat y$ является выходом модели. $\sigma$ - это логистическая функция, который преобразует действительное число $\mathbb R$ в вероятность $[0,1]$.
Однако эта потеря страдает от проблем численной нестабильности. Самое главное, что $\lim_{x\rightarrow0}\log(x)=\infty$ приводит к неустойчивости в процессе оптимизации. Рекомендуется посмотреть следующее упрощение. Эта функция эквивалентна первой и не так подвержена численной неустойчивости:
$$\mathcal L_{BCE} = \hat y - y\hat y + \log\left(1+\exp(-\hat y)\right) \space [2]$$
Вывод численно стабильной формулы BCE лосса [1 балл]¶
Выведите из формулы [1] формулу [2]:
$$\mathcal L_{BCE}(y, \hat y) = -\sum_i \left[y_i\log\sigma(\hat y_i) + (1-y_i)\log(1-\sigma(\hat y_i))\right] \space [1]$$
$$\mathcal L_{BCE} = \hat y - y\hat y + \log\left(1+\exp(-\hat y)\right) \space [2]$$
Не забываем, что здесь $\hat y_i$ - это логиты сети, не вероятности и не лейблы.
Ответ:
$$\mathcal \log\sigma(\hat y_i) = \log\frac{1}{1+e^{-\hat y_i}}=\log1 - \log(1+e^{-\hat y_i}) = - \log(1+e^{-\hat y_i})$$
$$\mathcal \log(1 - \sigma(\hat y_i)) = \log(1 - \frac{1}{1+e^{-\hat y_i}}) = \log\frac{e^{-\hat y_i}}{1+e^{-\hat y_i}} = \log(e^{-\hat y_i}) - \log(1+e^{-\hat y_i}) = {-\hat y_i} - \log(1+e^{-\hat y_i})$$
$$\mathcal L_{BCE}(y, \hat y) = -\sum_i \left[y_i\log\sigma(\hat y_i) + (1-y_i)\log(1-\sigma(\hat y_i))\right] = - \sum_i \left[y_i( - \log(1+e^{-\hat y_i})) + (1-y_i)({-\hat y_i} - \log(1+e^{-\hat y_i}))\right]=- \sum_i \left[-y_i \log(1+e^{-\hat y_i}) -{\hat y_i} - \log(1+e^{-\hat y_i}) + y_i{\hat y_i} + y_i \log(1+e^{-\hat y_i})\right] = \sum_i \left[{\hat y_i} - y_i{\hat y_i} + \log(1+e^{-\hat y_i}) \right]$$
Реализуйте в коде оба варианта лосса [1 балл]¶
Реализуйте следующие функции:
bce_true()- честная прямая реализация лосса с формулой $$\mathcal L_{BCE}(y, \hat y) = -\sum_i \left[y_i\log\sigma(\hat y_i) + (1-y_i)\log(1-\sigma(\hat y_i))\right].$$bce_loss()- реализация формулы, которую мы вывели $$\mathcal L_{BCE} = \hat y - y\hat y + \log\left(1+\exp(-\hat y)\right).$$
И сравните результаты функций с реализацией Pytorch:
bce_torch()bce_torch_with_logits()
import torch.nn.functional as F
import torch.nn as nn
bce_torch = nn.BCELoss(reduction='sum') # (sigmoid(y_pred), y_real)
bce_torch_with_logits = nn.BCEWithLogitsLoss(reduction='sum')
def bce_loss(y_pred, y_real):
return torch.sum(y_pred - y_pred*y_real + torch.log(1 + torch.exp(-y_pred)))
def bce_true(y_pred, y_real):
y = y_real
p = 1 / (1 + torch.exp(-y_pred))
return -torch.sum(y * torch.log(p) + (1 - y) * torch.log(1 - p))
Проверим корректность работы на простом примере
y_pred = torch.randn(3, 2, requires_grad=False)
y_true = torch.rand(3, 2, requires_grad=False)
print(f'BCE loss from scratch bce_loss = {bce_loss(y_pred, y_true)}')
print(f'BCE loss честно посчитанный = {bce_true(y_pred, y_true)}')
print(f'BCE loss from torch bce_torch = {bce_torch(torch.sigmoid(y_pred), y_true)}')
print(f'BCE loss from torch with logits bce_torch = {bce_torch_with_logits(y_pred, y_true)}')
BCE loss from scratch bce_loss = 4.616026401519775 BCE loss честно посчитанный = 4.616025924682617 BCE loss from torch bce_torch = 4.616025924682617 BCE loss from torch with logits bce_torch = 4.616025924682617
Инструкции assert в Python — это булевы выражения, которые проверяют, является ли условие истинным (True). Внизу в коде мы проверяем функция bce_loss() выдает тот же результат, что и функция из Pytorch или нет. Если равенства не будет, что будет означать, что результаты функций не совпадают, а значит вы неправильно реализовали фукнцию bce_loss(), assert возвратит ошибку.
Функция numpy.isclose() используется для сравнения двух чисел с учётом допустимой погрешности. Она особенно полезна при работе с числами с плавающей точкой, где точное сравнение может быть проблематичным из-за ограничений представления таких чисел в компьютере.
Как она работает?
numpy.isclose(a, b, rtol=1e-05, atol=1e-08) принимает два числа (a и b) и сравнивает их, учитывая относительную и абсолютную погрешность. Если разница между двумя числами меньше заданного порога, функция возвращает True, иначе — False.
Параметры:
rtol: Относительная погрешность (по умолчанию 1e-05). Используется для определения разницы относительно большего значения.
atol: Абсолютная погрешность (по умолчанию 1e-08). Определяет минимальную разницу, которую следует учитывать.
Мы будем использовать assert и numpy.isclose() для проверки корректности нашего кода.
assert np.isclose(bce_loss(y_pred, y_true), bce_torch(torch.sigmoid(y_pred), y_true))
assert np.isclose(bce_loss(y_pred, y_true), bce_torch_with_logits(y_pred, y_true))
assert np.isclose(bce_true(y_pred, y_true), bce_torch(torch.sigmoid(y_pred), y_true))
assert np.isclose(bce_true(y_pred, y_true), bce_torch_with_logits(y_pred, y_true))
Давайте теперь посчитаем на простом примере, но с теми же размерностями, что и в датасете
y_pred = torch.randn((2, 1, 3, 3), requires_grad=False)
y_true = torch.randint(0, 2, (2, 1, 3, 3))
print(f'BCE loss from scratch bce_loss = {bce_loss(y_pred, y_true)}')
print(f'BCE loss честно посчитанный = {bce_true(y_pred, y_true)}')
print(f'BCE loss from torch bce_torch = {bce_torch(torch.sigmoid(y_pred), y_true.to(torch.float))}')
print(f'BCE loss from torch with logits bce_torch = {bce_torch_with_logits(y_pred, y_true.to(torch.float))}')
BCE loss from scratch bce_loss = 14.737800598144531 BCE loss честно посчитанный = 14.737801551818848 BCE loss from torch bce_torch = 14.737801551818848 BCE loss from torch with logits bce_torch = 14.737800598144531
assert np.isclose(bce_loss(y_pred, y_true), bce_torch(torch.sigmoid(y_pred), y_true.to(torch.float)))
assert np.isclose(bce_loss(y_pred, y_true), bce_torch_with_logits(y_pred, y_true.to(torch.float)))
assert np.isclose(bce_true(y_pred, y_true), bce_torch(torch.sigmoid(y_pred), y_true.to(torch.float)))
assert np.isclose(bce_true(y_pred, y_true), bce_torch_with_logits(y_pred, y_true.to(torch.float)))
Давайте посчитаем на реальных логитах и сегментационной маске:
!gdown --folder 1EX0RW1TRQVkLmR1h6miCQqyhYPFyg28M
Retrieving folder contents Processing file 1--WxvBdpMn_NOmYPf3a4au8MHzfx5baC labels.pt Processing file 1-0A7_CS_vKiSCkgIDJ4joThCEcFedA3I logits.pt Retrieving folder contents completed Building directory structure Building directory structure completed Downloading... From: https://drive.google.com/uc?id=1--WxvBdpMn_NOmYPf3a4au8MHzfx5baC To: /content/for_asserts/labels.pt 100% 1.18k/1.18k [00:00<00:00, 5.16MB/s] Downloading... From: https://drive.google.com/uc?id=1-0A7_CS_vKiSCkgIDJ4joThCEcFedA3I To: /content/for_asserts/logits.pt 100% 1.18k/1.18k [00:00<00:00, 519kB/s] Download completed
path_to_dummy_samples = '/content/for_asserts'
dummpy_sample = {'logits': torch.load(f'{path_to_dummy_samples}/logits.pt'),
'labels': torch.load(f'{path_to_dummy_samples}/labels.pt')}
dummpy_sample['labels'] = dummpy_sample['labels'].to(DEVICE)
dummpy_sample['logits'] = dummpy_sample['logits'].to(DEVICE)
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize= (10,3*10))
ax1.imshow(dummpy_sample['labels'].squeeze(1)[0].cpu())
ax1.set_title("Original")
ax2.imshow(dummpy_sample['logits'].sigmoid().squeeze(1)[0].cpu())
for (j,i),label in np.ndenumerate(dummpy_sample['logits'].sigmoid().squeeze(1)[0].cpu()):
if label < 0.5:
color = 'white'
else:
color = 'black'
ax2.text(i,j,round(label,3), color=color, ha='center',va='center')
ax2.set_title("Predicted Probabilities")
ax3.imshow((dummpy_sample['logits'].sigmoid() > 0.5).squeeze(1)[0].cpu())
ax3.set_title("Predicted Mask")
plt.show()
Проверяем на данном примере:
bce_loss_score = bce_loss(dummpy_sample['logits'].cpu(), dummpy_sample['labels'].cpu())
bce_true_score = bce_true(dummpy_sample['logits'].cpu(), dummpy_sample['labels'].cpu())
bce_torch_score = bce_torch(torch.sigmoid(dummpy_sample['logits'].cpu()), dummpy_sample['labels'].cpu().float())
bce_torch_with_logits_score = bce_torch_with_logits(dummpy_sample['logits'].cpu(), dummpy_sample['labels'].cpu().float())
assert np.isclose(bce_loss_score, bce_torch_score)
assert np.isclose(bce_loss_score, bce_torch_with_logits_score)
assert np.isclose(bce_true_score, bce_torch_score)
assert np.isclose(bce_true_score, bce_torch_with_logits_score)
Шаг 4. Модель SegNet [2 балла]¶
Ваше задание здесь состоит в том, чтобы реализовать SegNet архитектуру.
- Badrinarayanan, V., Kendall, A., & Cipolla, R. (2015). SegNet: A deep convolutional encoder-decoder architecture for image segmentation
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import torch.optim as optim
from time import time
from matplotlib import rcParams
rcParams['figure.figsize'] = (15,4)
Внимательно посмотрите из чего состоит модель и для чего выбраны те или иные блоки. Для этого скачаем и изучим feature extractor VGG-16, который лежит в основе SegNet.
model_vgg16 = models.vgg16(weights = models.VGG16_Weights.IMAGENET1K_V1)
model_vgg16
VGG(
(features): Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace=True)
(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): ReLU(inplace=True)
(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(6): ReLU(inplace=True)
(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(8): ReLU(inplace=True)
(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(11): ReLU(inplace=True)
(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(13): ReLU(inplace=True)
(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(15): ReLU(inplace=True)
(16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(18): ReLU(inplace=True)
(19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(20): ReLU(inplace=True)
(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(22): ReLU(inplace=True)
(23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(25): ReLU(inplace=True)
(26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(27): ReLU(inplace=True)
(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(29): ReLU(inplace=True)
(30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
(classifier): Sequential(
(0): Linear(in_features=25088, out_features=4096, bias=True)
(1): ReLU(inplace=True)
(2): Dropout(p=0.5, inplace=False)
(3): Linear(in_features=4096, out_features=4096, bias=True)
(4): ReLU(inplace=True)
(5): Dropout(p=0.5, inplace=False)
(6): Linear(in_features=4096, out_features=1000, bias=True)
)
)
Feature extractor VGG-16 состоит из 5 блоков:
- два блока со структурой: Conv2d -> ReLU -> Conv2d -> ReLU -> MaxPool2d
- три блока со структурой: Conv2d -> ReLU -> Conv2d -> ReLU -> Conv2d -> ReLU -> MaxPool2d
В первом блоке - на входе три канала (по числу каналов в изображениях), которые конволюционный слой преобразует в 64 канала.
Во втором, третьем и четвертом блоках первый конволюционный слой удваивает количество каналов, а последующие конволюционные слои не меняют количество каналов.
В последнем блоке число каналов от слоя к слою не меняется.
Теперь напишем код одного блока энкодера нашей модели SegNet.
# Параметрами блока будут:
# - количество каналов на входе
# - количество каналов на выходе
# - глубина блока (2 или 3, по количеству конволюционных слоев)
# - kernel_size и padding
#
class EncoderBlock(nn.Module):
def __init__(self, in_channels, out_channels, depth, kernel_size = 3, padding = 1):
super(EncoderBlock, self).__init__() # инициируем экземляр класса, наследующего от nn.Module
self.layers = nn.ModuleList() # в self.layers будем добавлять слои блока
# дальше реализуем то, что на картинке выше обозначено Conv + Batch Normalization + ReLU
self.layers.append(nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = kernel_size, padding = padding))
self.layers.append(nn.BatchNorm2d(out_channels))
self.layers.append(nn.ReLU(inplace=True))
# цикл for помогает использовать один код для блоков как с глубиной 2, так и с глубиной 3
for i in range(depth-1):
self.layers.append(nn.Conv2d(in_channels = out_channels, out_channels = out_channels, kernel_size = kernel_size, padding = padding))
self.layers.append(nn.BatchNorm2d(out_channels))
self.layers.append(nn.ReLU(inplace=True))
self.maxpooling = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True) #добавляем MaxPool с индексами для последующего Unpooling
# Обратите внимание: на вход метод forward() получает карту признаков (х),
# а возвращает карту признаков и индексы для последующего Unpooling
def forward(self, x):
for layer in self.layers:
x = layer(x)
size = x.size()
x, indices = self.maxpooling(x)
return x, indices, size
По аналогии напишите код одного блока декодера.
К карте признаков на входе каждого блока примеяется nn.MaxUnpool2d с индексами из симметричного блока энкодера. Затем повторяется связка Conv + Batch Normalization + ReLU. Количество каналов меняется зеркально блокам энкодера:
- в первом блоке декодера количество каналов не меняется
- во 2-4 блоках декодера количество каналов уменьшается в 2 раза после прохождения последнего конволюционного слоя
- на выходе из последнего блока декодера 1 канал
Обратите внимание, что после последней конволюции последнего блока декодера не применяется батч-нормализация и функция активации.
class DecoderBlock(nn.Module):
def __init__(self, in_channels, out_channels, depth, kernel_size = 3, padding = 1):
super().__init__()
self.layers = nn.ModuleList()
for _ in range(depth):
self.layers.append(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding=padding))
self.layers.append(nn.BatchNorm2d(out_channels))
self.layers.append(nn.ReLU(inplace=True))
in_channels = out_channels
self.maxunpooling = nn.MaxUnpool2d(kernel_size=2, stride=2)
def forward(self, x, indices, output_size):
x = self.maxunpooling(x, indices, output_size)
for layer in self.layers:
x = layer(x)
return x
Соединим блоки энкодера и декодера в модель SegNet:
class SegNet(nn.Module):
def __init__(self, in_channels=3, out_channels = 1, num_features = 64) -> None:
super(SegNet, self).__init__()
# Encoder
self.encoder0 = EncoderBlock(in_channels, num_features, depth=2)
self.encoder1 = EncoderBlock(num_features, num_features * 2, depth=2)
self.encoder2 = EncoderBlock(num_features * 2, num_features * 4, depth=3)
self.encoder3 = EncoderBlock(num_features * 4, num_features * 8, depth=3)
# Encoder bottleneck - количество каналов на входе и на выходе одинаково
self.encoder4 = EncoderBlock(num_features * 8, num_features * 8, depth=3)
# Decoder bottleneck
self.decoder0 = DecoderBlock(num_features * 8, num_features * 8, depth=3)
# Decoder
self.decoder1 = DecoderBlock(num_features * 8, num_features * 4, depth=3)
self.decoder2 = DecoderBlock(num_features * 4, num_features * 2, depth=3)
self.decoder3 = DecoderBlock(num_features * 2, num_features, depth=2)
self.decoder4 = DecoderBlock(num_features, num_features, depth=2)
self.final = nn.Conv2d(num_features, out_channels, kernel_size=1)
def forward(self, x):
# encoder
e0, ind0, size0 = self.encoder0(x)
e1, ind1, size1 = self.encoder1(e0)
e2, ind2, size2 = self.encoder2(e1)
e3, ind3, size3 = self.encoder3(e2)
e4, ind4, size4 = self.encoder4(e3)
# Decoder
d0 = self.decoder0(e4, ind4, size4)
d1 = self.decoder1(d0, ind3, size3)
d2 = self.decoder2(d1, ind2, size2)
d3 = self.decoder3(d2, ind1, size1)
d4 = self.decoder4(d3, ind0, size0)
output = self.final(d4)
return output # no activation
Шаг 5. Тренировка модели [1 балл]¶
Напишите функции для обучения модели.
from tqdm.notebook import tqdm
def fit_one_epoch(model, train_dataloader, optimizer, loss_func):
'''
args:
model - модель для обучения
train_dataloader - loader с выборкой для обучения модели
optimizer - оптимизатор, взятый из модуля `torch.optim`
loss_func - функция потерь, взятая из модуля `torch.nn`
функция возвращает метрику accuracy по эпохе на данных из train_dataloader
'''
model.train()
avg_loss = 0
visualized = False
for X_batch, y_batch in tqdm(train_dataloader):
X_batch = X_batch.to(DEVICE)
y_batch = y_batch.to(DEVICE)
optimizer.zero_grad()
outp = model(X_batch)
prob = torch.sigmoid(outp)
y_pred = (prob > 0.5).long()
if not visualized:
visualized = True
visualize(X_batch, y_batch, y_pred)
loss = loss_func(outp, y_batch)
loss.backward()
optimizer.step()
avg_loss += loss.item()
return avg_loss / len(train_dataloader)
def eval_one_epoch(model, val_dataloader, loss_func):
'''
args:
model - модель для обучения
val_dataloader - loader с валидационной/тестовой выборкой
'''
iou_score = JaccardIndex(threshold=0.5, task="binary", average='none').to(DEVICE)
model.eval()
avg_loss = 0
avg_iou = 0
visualized = False
with torch.no_grad():
for X_batch, y_batch in tqdm(val_dataloader):
X_batch = X_batch.to(DEVICE)
y_batch = y_batch.to(DEVICE)
outp = model(X_batch)
prob = torch.sigmoid(outp)
y_pred = (prob > 0.5).long()
loss = loss_func(outp, y_batch)
iou = iou_score(y_pred, y_batch)
avg_loss += loss.item()
avg_iou += iou
if not visualized:
visualized = True
visualize(X_batch, y_batch, y_pred)
avg_loss = avg_loss / len(val_dataloader)
avg_iou = avg_iou / len(val_dataloader)
return avg_loss, avg_iou
def visualize(X_batch, y_batch, pred, n=1):
batch_size = X_batch.shape[0]
plt.figure(figsize=(10, 3*n))
for i in range(n):
img = X_batch[i].permute(1,2,0).cpu().numpy()
true_mask = y_batch[i].cpu().squeeze().numpy()
pred_mask = pred[i].squeeze().cpu().numpy()
plt.subplot(n, 3, i*3 + 1)
plt.title("Image")
plt.imshow(img)
plt.axis("off")
plt.subplot(n, 3, i*3 + 2)
plt.title("GT mask")
plt.imshow(true_mask, cmap="nipy_spectral")
plt.axis("off")
plt.subplot(n, 3, i*3 + 3)
plt.title("Pred mask")
plt.imshow(pred_mask, cmap="nipy_spectral")
plt.axis("off")
plt.show()
def train_func(model, num_epochs, dataloaders, optimizer, loss_func):
'''
args:
model - модель для обучения
num_epochs - количество эпох
dataloaders - словарь loader'ов с обучающей и валидационной выборками
optimizer - оптимизатор, взятый из модуля `torch.optim`
loss_func - функция потерь, взятая из модуля `torch.nn`
функция возвращает loss на обучающей и валидационной выборках на каждой эпохе, а также метрику IoU на валидационной выборке
'''
model = model.to(DEVICE)
score = {"train_loss": [], "val_loss": [], "val_iou": []}
for epoch in range(num_epochs):
print(f"\nEpoch: {epoch+1}")
loss_train = fit_one_epoch(model = model, train_dataloader = dataloaders['train'], optimizer = optimizer, loss_func = loss_func)
print(f"Loss train: {loss_train}\n")
loss_val, iou_val = eval_one_epoch(model = model, val_dataloader = dataloaders['val'], loss_func=loss_func)
print(f"Loss valid: {loss_val}\n")
print(f"IoU valid: {iou_val}\n")
score['train_loss'].append(loss_train)
score['val_loss'].append(loss_val)
score['val_iou'].append(iou_val)
return model, score
Обучите модель SegNet. В качестве оптимайзера можно взять Adam.
model_baseline = SegNet()
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model_baseline.parameters(), lr = 1e-3)
model_baseline, score_baseline = train_func(model_baseline, 50, loaders, optimizer, criterion)
Epoch: 1
0%| | 0/4 [00:00<?, ?it/s]
Loss train: 0.7088902145624161
0%| | 0/2 [00:00<?, ?it/s]
Loss valid: 0.6806586980819702 IoU valid: 0.0 Epoch: 2
0%| | 0/4 [00:00<?, ?it/s]
Loss train: 0.6382251232862473
0%| | 0/2 [00:00<?, ?it/s]
Loss valid: 0.6685355007648468 IoU valid: 0.0 Epoch: 3
0%| | 0/4 [00:00<?, ?it/s]
Loss train: 0.500316396355629
0%| | 0/2 [00:00<?, ?it/s]
Loss valid: 0.6420924067497253 IoU valid: 0.38382676243782043 Epoch: 4
0%| | 0/4 [00:00<?, ?it/s]
Loss train: 0.4157711789011955
0%| | 0/2 [00:00<?, ?it/s]
Loss valid: 0.6269842386245728 IoU valid: 0.42938777804374695 Epoch: 5
0%| | 0/4 [00:00<?, ?it/s]
Loss train: 0.37223953753709793
0%| | 0/2 [00:00<?, ?it/s]
Loss valid: 0.6746330857276917 IoU valid: 0.3734997510910034 Epoch: 6
0%| | 0/4 [00:00<?, ?it/s]
Loss train: 0.3497515767812729
0%| | 0/2 [00:00<?, ?it/s]
Loss valid: 0.6309003829956055 IoU valid: 0.4095335602760315 Epoch: 7
0%| | 0/4 [00:00<?, ?it/s]
Loss train: 0.3247246891260147
0%| | 0/2 [00:00<?, ?it/s]
Loss valid: 0.6306696534156799 IoU valid: 0.43107739090919495 Epoch: 8
0%| | 0/4 [00:00<?, ?it/s]
Loss train: 0.3140593394637108
0%| | 0/2 [00:00<?, ?it/s]
Loss valid: 0.704652726650238 IoU valid: 0.407431960105896 Epoch: 9
0%| | 0/4 [00:00<?, ?it/s]
Loss train: 0.3083975985646248
0%| | 0/2 [00:00<?, ?it/s]
Loss valid: 0.4111880213022232 IoU valid: 0.5894871950149536 Epoch: 10
0%| | 0/4 [00:00<?, ?it/s]
Loss train: 0.2947194203734398
0%| | 0/2 [00:00<?, ?it/s]
Loss valid: 0.5216855704784393 IoU valid: 0.5018594264984131 Epoch: 11
0%| | 0/4 [00:00<?, ?it/s]
Loss train: 0.28925783932209015
0%| | 0/2 [00:00<?, ?it/s]
Loss valid: 0.30654121935367584 IoU valid: 0.6960251331329346 Epoch: 12
0%| | 0/4 [00:00<?, ?it/s]
Loss train: 0.28432005643844604
0%| | 0/2 [00:00<?, ?it/s]
Loss valid: 0.4548138678073883 IoU valid: 0.5753812789916992 Epoch: 13
0%| | 0/4 [00:00<?, ?it/s]
Loss train: 0.2550861984491348
0%| | 0/2 [00:00<?, ?it/s]
Loss valid: 0.4729755222797394 IoU valid: 0.5680891275405884 Epoch: 14
0%| | 0/4 [00:00<?, ?it/s]
Loss train: 0.24712393060326576
0%| | 0/2 [00:00<?, ?it/s]
Loss valid: 0.4965546280145645 IoU valid: 0.5238682627677917 Epoch: 15
0%| | 0/4 [00:00<?, ?it/s]
Loss train: 0.24329785630106926
0%| | 0/2 [00:00<?, ?it/s]
Loss valid: 0.5635707378387451 IoU valid: 0.5173154473304749 Epoch: 16
0%| | 0/4 [00:00<?, ?it/s]
Loss train: 0.23090828210115433
0%| | 0/2 [00:00<?, ?it/s]
Loss valid: 0.4164176285266876 IoU valid: 0.5875464081764221 Epoch: 17
0%| | 0/4 [00:00<?, ?it/s]
Loss train: 0.2251473069190979
0%| | 0/2 [00:00<?, ?it/s]
Loss valid: 0.40885478258132935 IoU valid: 0.6005829572677612 Epoch: 18
0%| | 0/4 [00:00<?, ?it/s]
Loss train: 0.2389150969684124
0%| | 0/2 [00:00<?, ?it/s]
Loss valid: 0.4225718677043915 IoU valid: 0.6117110848426819 Epoch: 19
0%| | 0/4 [00:00<?, ?it/s]
Loss train: 0.21030806750059128
0%| | 0/2 [00:00<?, ?it/s]
Loss valid: 0.5596248060464859 IoU valid: 0.5397874712944031 Epoch: 20
0%| | 0/4 [00:00<?, ?it/s]
Loss train: 0.23217927664518356
0%| | 0/2 [00:00<?, ?it/s]
Loss valid: 0.26289648562669754 IoU valid: 0.735488772392273 Epoch: 21
0%| | 0/4 [00:00<?, ?it/s]
Loss train: 0.24586381763219833
0%| | 0/2 [00:00<?, ?it/s]
Loss valid: 0.6872997879981995 IoU valid: 0.4611112177371979 Epoch: 22
0%| | 0/4 [00:00<?, ?it/s]
Loss train: 0.22472476214170456
0%| | 0/2 [00:00<?, ?it/s]
Loss valid: 0.34815242886543274 IoU valid: 0.6736955046653748 Epoch: 23
0%| | 0/4 [00:00<?, ?it/s]
Loss train: 0.23693016543984413
0%| | 0/2 [00:00<?, ?it/s]
Loss valid: 0.22739911824464798 IoU valid: 0.7620818614959717 Epoch: 24
0%| | 0/4 [00:00<?, ?it/s]
Loss train: 0.23085859790444374
0%| | 0/2 [00:00<?, ?it/s]
Loss valid: 0.3239000588655472 IoU valid: 0.6869162321090698 Epoch: 25
0%| | 0/4 [00:00<?, ?it/s]
Loss train: 0.22607507184147835
0%| | 0/2 [00:00<?, ?it/s]
Loss valid: 0.5454607158899307 IoU valid: 0.5920796990394592 Epoch: 26
0%| | 0/4 [00:00<?, ?it/s]
Loss train: 0.21723804995417595
0%| | 0/2 [00:00<?, ?it/s]
Loss valid: 0.3899311274290085 IoU valid: 0.670615553855896 Epoch: 27
0%| | 0/4 [00:00<?, ?it/s]
Loss train: 0.19333281368017197
0%| | 0/2 [00:00<?, ?it/s]
Loss valid: 0.2741902843117714 IoU valid: 0.7241814732551575 Epoch: 28
0%| | 0/4 [00:00<?, ?it/s]
Loss train: 0.21157973632216454
0%| | 0/2 [00:00<?, ?it/s]
Loss valid: 0.38126857578754425 IoU valid: 0.6279032230377197 Epoch: 29
0%| | 0/4 [00:00<?, ?it/s]
Loss train: 0.18963951990008354
0%| | 0/2 [00:00<?, ?it/s]
Loss valid: 0.27555491775274277 IoU valid: 0.7111327052116394 Epoch: 30
0%| | 0/4 [00:00<?, ?it/s]
Loss train: 0.18125222250819206
0%| | 0/2 [00:00<?, ?it/s]
Loss valid: 0.17716316878795624 IoU valid: 0.8048577904701233 Epoch: 31
0%| | 0/4 [00:00<?, ?it/s]
Loss train: 0.18316393345594406
0%| | 0/2 [00:00<?, ?it/s]
Loss valid: 0.27888912707567215 IoU valid: 0.7068072557449341 Epoch: 32
0%| | 0/4 [00:00<?, ?it/s]
Loss train: 0.18919847905635834
0%| | 0/2 [00:00<?, ?it/s]
Loss valid: 0.2049947753548622 IoU valid: 0.7776749730110168 Epoch: 33
0%| | 0/4 [00:00<?, ?it/s]
Loss train: 0.16759219393134117
0%| | 0/2 [00:00<?, ?it/s]
Loss valid: 0.34727734327316284 IoU valid: 0.6685512065887451 Epoch: 34
0%| | 0/4 [00:00<?, ?it/s]
Loss train: 0.19003837928175926
0%| | 0/2 [00:00<?, ?it/s]
Loss valid: 0.2719089090824127 IoU valid: 0.7174249887466431 Epoch: 35
0%| | 0/4 [00:00<?, ?it/s]
Loss train: 0.17706162855029106
0%| | 0/2 [00:00<?, ?it/s]
Loss valid: 0.1988501325249672 IoU valid: 0.7730058431625366 Epoch: 36
0%| | 0/4 [00:00<?, ?it/s]
Loss train: 0.16342436894774437
0%| | 0/2 [00:00<?, ?it/s]
Loss valid: 0.1810378059744835 IoU valid: 0.8064121007919312 Epoch: 37
0%| | 0/4 [00:00<?, ?it/s]
Loss train: 0.1796884685754776
0%| | 0/2 [00:00<?, ?it/s]
Loss valid: 0.21672917157411575 IoU valid: 0.7849995493888855 Epoch: 38
0%| | 0/4 [00:00<?, ?it/s]
Loss train: 0.16774186491966248
0%| | 0/2 [00:00<?, ?it/s]
Loss valid: 0.21554523706436157 IoU valid: 0.7795153856277466 Epoch: 39
0%| | 0/4 [00:00<?, ?it/s]
Loss train: 0.1882936768233776
0%| | 0/2 [00:00<?, ?it/s]
Loss valid: 0.22181197255849838 IoU valid: 0.7836636304855347 Epoch: 40
0%| | 0/4 [00:00<?, ?it/s]
Loss train: 0.16846414655447006
0%| | 0/2 [00:00<?, ?it/s]
Loss valid: 0.2256392240524292 IoU valid: 0.7838817834854126 Epoch: 41
0%| | 0/4 [00:00<?, ?it/s]
Loss train: 0.17166605219244957
0%| | 0/2 [00:00<?, ?it/s]
Loss valid: 0.3779662996530533 IoU valid: 0.6729598045349121 Epoch: 42
0%| | 0/4 [00:00<?, ?it/s]
Loss train: 0.16688699647784233
0%| | 0/2 [00:00<?, ?it/s]
Loss valid: 0.24135907739400864 IoU valid: 0.7575369477272034 Epoch: 43
0%| | 0/4 [00:00<?, ?it/s]
Loss train: 0.1529532317072153
0%| | 0/2 [00:00<?, ?it/s]
Loss valid: 0.2166372314095497 IoU valid: 0.7783081531524658 Epoch: 44
0%| | 0/4 [00:00<?, ?it/s]
Loss train: 0.13592077419161797
0%| | 0/2 [00:00<?, ?it/s]
Loss valid: 0.22347378730773926 IoU valid: 0.7527279853820801 Epoch: 45
0%| | 0/4 [00:00<?, ?it/s]
Loss train: 0.1690601073205471
0%| | 0/2 [00:00<?, ?it/s]
Loss valid: 0.31689217686653137 IoU valid: 0.6695342063903809 Epoch: 46
0%| | 0/4 [00:00<?, ?it/s]
Loss train: 0.13234772719442844
0%| | 0/2 [00:00<?, ?it/s]
Loss valid: 0.334345281124115 IoU valid: 0.6826684474945068 Epoch: 47
0%| | 0/4 [00:00<?, ?it/s]
Loss train: 0.1293657124042511
0%| | 0/2 [00:00<?, ?it/s]
Loss valid: 0.21620211750268936 IoU valid: 0.7776235342025757 Epoch: 48
0%| | 0/4 [00:00<?, ?it/s]
Loss train: 0.14344285055994987
0%| | 0/2 [00:00<?, ?it/s]
Loss valid: 0.24861302226781845 IoU valid: 0.734126091003418 Epoch: 49
0%| | 0/4 [00:00<?, ?it/s]
Loss train: 0.14622491039335728
0%| | 0/2 [00:00<?, ?it/s]
Loss valid: 0.1953466534614563 IoU valid: 0.7786290645599365 Epoch: 50
0%| | 0/4 [00:00<?, ?it/s]
Loss train: 0.1356885675340891
0%| | 0/2 [00:00<?, ?it/s]
Loss valid: 0.2108900099992752 IoU valid: 0.7840803861618042
Шаг 6. Инференс [1 балл]¶
После обучения модели напишите функцию теста, воспользуйтесь лучшим чекпоинтом и протестируйте работу модели на тестовой выборке.
def test(model, test_dataloader):
model.eval()
iou_score = JaccardIndex(threshold=0.5, task="binary", average='none').to(DEVICE)
avg_iou = 0
with torch.no_grad():
for X_batch, Y_batch in test_dataloader:
X_batch = X_batch.to(DEVICE)
Y_batch = Y_batch.to(DEVICE)
outp = model(X_batch)
prob = torch.sigmoid(outp)
y_pred = (prob > 0.5).long()
avg_iou += iou_score(y_pred, Y_batch)
return avg_iou / len(test_dataloader)
test_score = test(model_baseline, test_dataloader)
test_score.item()
0.8248778581619263
Задания: Мир других лоссов!¶
Пробуем другие функции потерь [2 балла]¶
В данном разделе вам потребуется имплементировать две функции потерь: DICE и Focal loss.
Dice Loss¶
1. Dice coefficient: Учитывая две маски $X$ и $Y$, общая метрика для измерения расстояния между этими двумя масками задается следующим образом:
$$D(X,Y)=\frac{2|X\cap Y|}{|X|+|Y|}$$
В терминах матрицы ошибок она будет считаться следующим образом:
$$D(X,Y) = \frac{2TP}{2TP + FP + FN}$$
Эта функция не является дифференцируемой, но это необходимое свойство для градиентного спуска. В данном случае мы можем приблизить его с помощью:
$$\mathcal L_D(X,Y) = 1- D(X, Y)$$
Hints (!):
- Не забудьте подумать о численной нестабильности, возникающей в математической формуле при ситуации, когда $\frac{0}{0}$, т.е. вам нужно добавить очень маленькое число, например $\epsilon = 1e^{-8}$, в обе части дроби при подсчете $D(X,Y)$:
$$D(X,Y) = \frac{2TP + ϵ}{2TP + FP + FN + ϵ}$$
Dice метрика(!), не лосс, считается похожим образом как IoU:
2.1. На вход вам приходят logits, т.е. значения от $-∞$ до $∞$. Их переводим в вероятности от 0 до 1 при помощи функции Sigmoid.
2.2. Фиксируем порог, например threshold=0.5, и всему что ниже порога ставим значение 0, всему что выше 1. Получаем предсказанную маску из 0 и 1.
2.3. Считаем TP, FP, FN
2.4. Считаем DICE метрику по формуле
Вы можете прописать для себя функцию dice_score() и сравнить с результатами работы функции из библиотеки torchmetrics.
Но с метрикой есть проблема, что она не дифференцируема, и если вы захотите просто взять и прописать
dice_loss= 1 -dice_score, Pytorch поругается на вас и скажет, что это недифференцируемая метрика. Чтобы посчитать dice_loss делаем следующие шаги:3.1. На вход вам приходят logits, т.е. значения от $-∞$ до $∞$. Их переводим в вероятности от 0 до 1 при помощи функции Sigmoid.
3.2. Здесь нам уже не нужно фиксировать порог, мы просто работаем с вероятностями. Значения вероятностей дифференцируемы и через них будут протекать градиенты.
3.3. Считаем TP, FP, FN также как и в Dice метрике, только вместо маски, подаем вероятности.
3.4. Считаем DICE метрику по формуле
3.5. Считаем лосс как Loss = 1 - DICE
Итак, давайте сначала пропишем dice_score.
def dice_score(logits: torch.Tensor, labels: torch.Tensor, threshold: float = 0.5):
'''
Это именно метрика, не лосс.
'''
eps = 1e-8
prob = torch.sigmoid(logits)
preds = (prob > threshold).int()
TP = (preds * labels).sum()
FP = (preds * (1 - labels)).sum()
FN = ((1 - preds) * labels).sum()
score = (2 * TP + eps) / (2 * TP + FP + FN + eps)
return score
Проверим на корректность функцию dice_score:
from torchmetrics.segmentation import DiceScore
dice = DiceScore(num_classes=1, average='micro').to(DEVICE)
dice(dummpy_sample['logits'].sigmoid() > 0.5, dummpy_sample['labels'].int())
/usr/local/lib/python3.12/dist-packages/torchmetrics/utilities/prints.py:43: UserWarning: DiceScore metric currently defaults to `average=micro`, but will change to`average=macro` in the v1.9 release. If you've explicitly set this parameter, you can ignore this warning. warnings.warn(*args, **kwargs)
tensor(0.6667, device='cuda:0')
assert dice(dummpy_sample['logits'].sigmoid()>0.5, dummpy_sample['labels'].to(int)) == dice_score(dummpy_sample['logits'], dummpy_sample['labels'])
Давайте теперь пропишем лосс и воспользуемся библиотекой segmentation-models-pytorch, чтобы убедиться в корректности нашей функции.
def dice_loss(logits: torch.Tensor, labels: torch.Tensor):
eps = 1e-8
probs = torch.sigmoid(logits)
TP = (probs * labels).sum()
FP = (probs * (1 - labels)).sum()
FN = ((1 - probs) * labels).sum()
dice = (2 * TP + eps) / (2 * TP + FP + FN + eps)
return 1 - dice.mean()
Проверка на корректность:
# проверьте, что у вас установлена библиотека
!pip install -q segmentation-models-pytorch
from segmentation_models_pytorch.losses import DiceLoss
dice_loss_torch = DiceLoss(mode='binary')
dice_loss_torch(dummpy_sample['logits'], dummpy_sample['labels'])
tensor(0.5756, device='cuda:0')
dice_loss(dummpy_sample['logits'], dummpy_sample['labels'])
tensor(0.5756, device='cuda:0')
assert dice_loss_torch(dummpy_sample['logits'], dummpy_sample['labels'].to(int)) == dice_loss(dummpy_sample['logits'], dummpy_sample['labels'])
Focal Loss¶
Окей, мы уже с вами умеем делать BCE loss:
$$\mathcal L_{BCE}(y, \hat y) = -\sum_i \left[y_i\log\sigma(\hat y_i) + (1-y_i)\log(1-\sigma(\hat y_i))\right].$$
Проблема с этой потерей заключается в том, что она имеет тенденцию приносить пользу классу большинства (фоновому) по отношению к классу меньшинства ( переднему). Поэтому обычно применяются весовые коэффициенты к каждому классу:
$$\mathcal L_{wBCE}(y, \hat y) = -\sum_i \alpha_i\left[y_i\log\sigma(\hat y_i) + (1-y_i)\log(1-\sigma(\hat y_i))\right].$$
Традиционно вес $\alpha_i$ определяется как обратная частота класса этого пикселя $i$, так что наблюдения миноритарного класса весят больше по отношению к классу большинства.
Из оригинальной статьи по Focal Loss:
$$p_t = \sigma(\hat y_i)y_i + (1 - \sigma(\hat y_i)) (1-y_i)$$
$$\mathcal L_{focal}(y, \hat y) = (1 - p_t)^{\gamma} \mathcal L_{BCE}(y_i, \hat y_i).$$
$$\mathcal L_{focal}(y, \hat y) = -\sum_i (1 - p_t)^{\gamma} \left[y_i\log\sigma(\hat y_i) + (1-y_i)\log(1-\sigma(\hat y_i))\right].$$
$$\mathcal L_{focal}(y, \hat y) = -\sum_i (1 - (\sigma(\hat y_i)y_i + (1 - \sigma(\hat y_i)) (1-y_i)))^{\gamma} \left[y_i\log\sigma(\hat y_i) + (1-y_i)\log(1-\sigma(\hat y_i))\right].$$
def focal_loss(y_real, y_pred, eps = 1e-6, gamma = 2):
p = torch.sigmoid(y_pred) * y_real + (1 - torch.sigmoid(y_pred))*(1 - y_real)
bce_loss = -(y_real * torch.log(torch.sigmoid(y_pred) + eps) + (1 - y_real) * torch.log(1 - torch.sigmoid(y_pred) + eps))
loss = (1 - p)**gamma * bce_loss
return loss.sum()
Проверка корректности функции:
from torchvision.ops import sigmoid_focal_loss
sigmoid_focal_loss(dummpy_sample['logits'], dummpy_sample['labels'], alpha=-1, gamma=2, reduction='sum').item()
3.616123676300049
assert torch.allclose(sigmoid_focal_loss(dummpy_sample['logits'], dummpy_sample['labels'], alpha=-1, gamma=2, reduction='sum'),
focal_loss(dummpy_sample['labels'], dummpy_sample['logits'], gamma=2.0),
rtol=1e-5,
atol=1e-8)
[BONUS] Мир сегментационных лоссов [5 баллов]¶
В данном блоке предлагаем вам написать одну функцию потерь самостоятельно. Для этого необходимо прочитать статью и имплементировать ее, и провести численное сравнение с предыдущими функциями.
Для изучения была выбрана статья Correlation Maximized Structural Similarity Loss for Semantic Segmentation
Основная идея Structural Similarity Loss:
Вместо того, чтобы смотреть на соответствие единичных пикселей, которые в свою очередь игнорируют зависимость между друг другом, этот метод предлагает смотреть, насколько коррелированы разные локальные участки предсказанной карты сегментации и истинной, и уделять внимание позициям, чьи предсказания приводят к низкой степени линейной корреляции.
Для реализации этой идеи авторы предлагают рассмотреть сумму обычной BCE loss и перевзвешинной BCE loss. К перевзвешенной при этом добавляют множитель ошибки - меру структурного сходства. Также для участков с маленькой ошибкой, перевзвешенная BCE loss будет зануляться, как бы предполагая, что на них все достаточно хорошо и ничего не нужно менять.
Для реализации этой функции потерь нужны следующие формулы:
Общая целевая функция:
$$ L_{all}(y, p) = \lambda L_{ce}(y, p) + (1-\lambda) L_{ssl}(y, p) \tag{17} $$
Классическая кросс-энтропия
$$ L_{ce}(y, p) = - \frac{1}{N} \sum_{n=1}^N \sum_{c=1}^C y_{n,c} \log(p_{n,c}) \tag{1} $$
$e$ — это общая абсолютная ошибка между стандартизованными нормализованными результатами истинных значений ($y$) и предсказаний ($p$). Это мера структурного различия (обратная корреляции). $C_4 = 0.01$ — стабилизирующий фактор.
$$ e = \left|\frac{y - \mu_y + C_4}{\sigma_y + C_4} - \frac{p - \mu_p + C_4}{\sigma_p + C_4}\right| \tag{10} $$
Маска для выбора примеров отбрасывает «легкие примеры» (те, для которых $e$ мало), тем самым реализуя стратегию Online Hard Example Mining (OHEM).
$$ f_{n,c} = 1_{\{e_{n,c} > \beta e_{\max}\}} \tag{11} $$
$L_{ssl}$ — это сигмоидальная кросс-энтропия $L_{ce}$, перевзвешенная структурной ошибкой $e$ и умноженная на маску $f_{n,c}$. $e_{n,c}$ используется как постоянный весовой коэффициент.
$$ L_{ssl}(y_{n,c}, p_{n,c}) = e_{n,c} f_{n,c} L_{ce}(y_{n,c}, p_{n,c}) \tag{12} $$
Итоговая функция потерь SSL по мини-батчу $L_{ssl}$ усредняется только по $M$ выбранным «трудным примерам». $$ L_{ssl}(y, p) = \frac{1}{M} \sum_{n=1}^N \sum_{c=1}^C L_{ssl}(y_{n,c}, p_{n,c}) \tag{13} $$
Для подсчета статистик необходимо будет пройтись гауссовым окном по картам сегментации.
Вспомогатльные формулы:
Локальное среднее: $$ \mu_y = \sum_{i=1}^{k^2} w_i y_i \tag{14} $$
Локальная дисперсия: $$ \sigma^2_y = \sum_{i=1}^{k^2} w_i(y_i - \mu_y)^2 = \sum_{i=1}^{k^2} w_i y^2_i - \mu^2_y \tag{15} $$
${w_i}$ - ${i}$ -тое значение в гауссовском окне (ядре)
from typing import Tuple
import torch
import torch.nn.functional as F
def make_gaussian_kernel(k: int, sigma: float):
assert k % 2 == 1, "kernel size must be odd"
half = k // 2
xs = torch.arange(-half, half+1, dtype=torch.float32)
ys = xs.view(-1, 1)
kernel = torch.exp(-(xs**2 + ys**2) / (2 * sigma**2))
kernel = kernel / kernel.sum()
return kernel
class BinaryStructuralSimilarityLoss(nn.Module):
def __init__(self, window_size=3, tau=0.1, lambda_ce=0.9, c=0.01, eps=1e-8, gauss_sigma=1.5):
super().__init__()
self.k = window_size
self.tau = tau
self.lambda_ce = lambda_ce
self.c = c
self.eps = eps
self.gauss_sigma = gauss_sigma
def __get_kernel(self):
return make_gaussian_kernel(self.k, self.gauss_sigma)
def count_abs_structural_error(self, probs: torch.Tensor, labels: torch.Tensor):
B = probs.shape[0]
p_unf, y_unf = self.make_patches(probs, labels) # (B, k*k, H*W)
kernel = self.__get_kernel()
statistics_p = self.compute_statistics_over_patch(probs, kernel) # (B,1,H,W)
statistics_y = self.compute_statistics_over_patch(labels, kernel)
mean_p, sigma_p = self.adjust_dims(statistics_p, B) # (B, k*k, H*W)
mean_y, sigma_y = self.adjust_dims(statistics_y, B)
z_p = (p_unf - mean_p + self.c) / (sigma_p + self.c)
z_y = (y_unf - mean_y + self.c) / (sigma_y + self.c)
error = torch.abs(z_y - z_p).sum(dim=1) # (B, H*W)
return error
def make_patches(self, probs: torch.Tensor, labels: torch.Tensor):
pad = self.k // 2
p_unf = F.unfold(probs, kernel_size=self.k, padding=pad) # (B, k*k, H*W)
l_unf = F.unfold(labels, kernel_size=self.k, padding=pad) # (B, k*k, H*W)
return p_unf, l_unf
def compute_statistics_over_patch(self, labels: torch.Tensor, kernel: torch.Tensor, eps: float = 1e-8):
pad = self.k // 2
kernel = kernel.view(1, 1, self.k, self.k).to(labels.device)
mean = F.conv2d(labels, kernel, padding=pad) # (B,1,H,W)
E2 = F.conv2d(labels * labels, kernel, padding=pad)
var = torch.clamp(E2 - mean * mean, min=0.0)
sigma = torch.sqrt(var + eps) # (B,1,H,W)
return mean, sigma
def adjust_dims(self, statistics: Tuple[torch.Tensor, torch.Tensor], num_batches: int):
k = self.k
B = num_batches
mean_exp = statistics[0].view(B, 1, -1) # (B, 1, H*W)
sigma_exp = statistics[1].view(B, 1, -1) # (B, 1, H*W)
mean_exp = mean_exp.repeat(1, k*k, 1) # (B, k*k, H*W)
sigma_exp = sigma_exp.repeat(1, k*k, 1)
return mean_exp, sigma_exp
def forward(self, logits: torch.Tensor, labels: torch.Tensor):
probs = torch.sigmoid(logits)
error = self.count_abs_structural_error(probs, labels)
indicator = (error > self.tau * error.max(dim=1, keepdim=True)[0]).float()
bce_full = F.binary_cross_entropy_with_logits(logits, labels, reduction='mean')
bce_pixel = F.binary_cross_entropy_with_logits(logits, labels, reduction='none').view(probs.shape[0], -1) # (B, H*W)
L_ssl_pixel = error * indicator * bce_pixel # (B, H*W)
M = indicator.sum(dim=1).clamp(min=1) # (B,)
L_ssl_batch = (L_ssl_pixel.sum(dim=1) / M) # (B,)
L_ssl = L_ssl_batch.mean()
L_all = self.lambda_ce * bce_full + (1 - self.lambda_ce) * L_ssl
return L_all
ssloss = BinaryStructuralSimilarityLoss()
ssloss(dummpy_sample['logits'], dummpy_sample['labels'])
tensor(0.5462, device='cuda:0')
Проведем численное сравнение с ранее использованными лоссами:
- Рандомный пример
logits = torch.tensor([[[[
-1., 1., 2., 1., -1.,
1., -2., -3., -2., 1.,
2., -3., -4., -3., 2.,
1., -2., -3., -2., 1.,
-1., 1., 2., 1., -1.,
]]]], dtype=torch.float32)
labels = torch.tensor([[[[
0., 1., 1., 1., 0.,
1., 0., 0., 0., 1.,
1., 0., 0., 0., 1.,
1., 0., 0., 0., 1.,
0., 1., 1., 1., 0.,
]]]], dtype=torch.float32)
print("SS Loss:", ssloss(logits, labels))
print("Dice Loss:", dice_loss(logits, labels))
print("Focal Loss:", focal_loss(logits, labels))
print("BCE Loss:", bce_loss(logits, labels))
SS Loss: tensor(0.2240) Dice Loss: tensor(0.1897) Focal Loss: tensor(2.3315) BCE Loss: tensor(4.9871)
- Пример с идеальным предсказанием
labels = torch.tensor([[
[0,0,0,0,0,0,0,0,0,0],
[0,0,0,1,1,1,1,0,0,0],
[0,0,0,1,1,1,1,0,0,0],
[0,0,0,1,1,1,1,0,0,0],
[0,0,0,1,1,1,1,0,0,0],
[0,0,0,1,1,1,1,0,0,0],
[0,0,0,1,1,1,1,0,0,0],
[0,0,0,1,1,1,1,0,0,0],
[0,0,0,0,0,0,0,0,0,0],
[0,0,0,0,0,0,0,0,0,0]
]], dtype=torch.float)
logits = torch.where(labels==1, torch.tensor(10.0), torch.tensor(-10.0)).unsqueeze(0)
# shape: (1,1,10,10)
labels = labels.unsqueeze(0)
print("SS Loss:", ssloss(logits, labels))
print("Dice Loss:", dice_loss(logits, labels))
print("Focal Loss:", focal_loss(logits, labels))
print("BCE Loss:", bce_loss(logits, labels))
SS Loss: tensor(4.1325e-05) Dice Loss: tensor(8.1062e-05) Focal Loss: tensor(-3668.2905) BCE Loss: tensor(0.0046)
- Пример с ужасным предсказанием
labels = torch.tensor([[
[0,0,0,0,0,0,0,0,0,0],
[0,0,0,1,1,1,1,0,0,0],
[0,0,0,1,1,1,1,0,0,0],
[0,0,0,1,1,1,1,0,0,0],
[0,0,0,1,1,1,1,0,0,0],
[0,0,0,1,1,1,1,0,0,0],
[0,0,0,1,1,1,1,0,0,0],
[0,0,0,1,1,1,1,0,0,0],
[0,0,0,0,0,0,0,0,0,0],
[0,0,0,0,0,0,0,0,0,0]
]], dtype=torch.float)
logits = torch.where(labels==1, torch.tensor(-10.0), torch.tensor(+10.0)).unsqueeze(0)
labels = labels.unsqueeze(0)
print("SS Loss:", ssloss(logits, labels))
print("Dice Loss:", dice_loss(logits, labels))
print("Focal Loss:", focal_loss(logits, labels))
print("BCE Loss:", bce_loss(logits, labels))
SS Loss: tensor(22.0742) Dice Loss: tensor(1.0000) Focal Loss: tensor(9086.8047) BCE Loss: tensor(1000.0046)
Т.к. механика всех этих функций потреь отличается, конечно, они не будут одинаковыми, но видно, что для всех примеров они в какой-то степени солидарны
Обучите SegNet на новых лоссах¶
Задание: обучите SegNet на новых лоссах и сравните все три лосса:
- При каком лоссе модель сходится быстрее?
- При каком лоссе модель выдает наилучшую метрику?
Напишите развернутый ответ на вопросы.
model_dice = SegNet()
criterion = DiceLoss(mode='binary')
optimizer = torch.optim.Adam(model_dice.parameters(), lr = 1e-3)
model_dice, score_dice = train_func(model_dice, 50, loaders, optimizer, criterion)
def focal_loss_wrapper(logits, labels):
loss = sigmoid_focal_loss(logits, labels, reduction='none')
return loss.mean()
model_focal = SegNet()
criterion = focal_loss_wrapper
optimizer = torch.optim.Adam(model_focal.parameters(), lr = 1e-3)
model_focal, score_focal = train_func(model_focal, 50, loaders, optimizer, criterion)
def ss_loss_wrapper(logits, labels):
loss = ssloss(logits, labels)
return loss.mean()
model_ssl = SegNet()
criterion = ss_loss_wrapper
optimizer = torch.optim.Adam(model_ssl.parameters(), lr = 1e-3)
model_ssl, score_ssl = train_func(model_ssl, 50, loaders, optimizer, criterion)
def to_float_list(x):
return [float(v) for v in x] # работает и для CUDA tensors
score_baseline['val_iou'] = to_float_list(score_baseline['val_iou'])
score_dice['val_iou'] = to_float_list(score_dice['val_iou'])
score_focal['val_iou'] = to_float_list(score_focal['val_iou'])
score_ssl['val_iou'] = to_float_list(score_ssl['val_iou'])
plt.plot(range(1, 51), score_baseline['val_iou'], label='score baseline')
plt.plot(range(1, 51), score_dice['val_iou'], label='score dice')
plt.plot(range(1, 51), score_focal['val_iou'], label='score focal')
plt.plot(range(1, 51), score_ssl['val_iou'], label='score ssl')
plt.title('IoU на валидационной выборке')
plt.xlabel('epoch')
plt.ylabel('IoU score')
plt.legend()
<matplotlib.legend.Legend at 0x78bc651fb410>
plt.plot(range(1, 51), score_baseline['val_loss'], label='score baseline')
plt.plot(range(1, 51), score_dice['val_loss'], label='score dice')
plt.plot(range(1, 51), score_focal['val_loss'], label='score focal')
plt.plot(range(1, 51), score_ssl['val_loss'], label='score ssl')
plt.ylim(0, 5)
plt.title('Loss на валидационной выборке')
plt.xlabel('epoch')
plt.ylabel('Loss')
plt.legend()
<matplotlib.legend.Legend at 0x78bc64063200>
test_dice = test(model_dice, test_dataloader)
test_focal = test(model_focal, test_dataloader)
test_ssl = test(model_ssl, test_dataloader)
print("Модель с Dice Loss имеет Iou:", test_dice.cpu().item())
print("Модель с Focal Loss имеет Iou:", test_focal.cpu().item())
print("Модель с SS Loss имеет Iou:", test_ssl.cpu().item())
Модель с Dice Loss имеет Iou: 0.8408982753753662 Модель с Focal Loss имеет Iou: 0.8346421718597412 Модель с SS Loss имеет Iou: 0.681615948677063
Наилучшую сходимость показал focal: он сразу же вышел на минимальную ошибку и единственный стабильно ее сохранял.
Наилучшую метрику на тесте выдал Dice, но стоит заметить, что Focal не сильно отстал, а на валидационной выборке был самым стабильным.
Задание: Новая модель!¶
Модель U-Net [2 балла]¶
U-Net — это архитектура нейронной сети, которая получает изображение и выводит его. Первоначально он был задуман для семантической сегментации (как мы ее будем использовать), но он настолько успешен, что с тех пор используется в других контекстах. Получая на вход медицинское изображение, он выведет изображение в оттенках серого, где интенсивность каждого пикселя зависит от вероятности того, что этот пиксель принадлежит интересующей нас области.
У нас в архитектуре все так же существует энкодер и декодер, как в SegNet, но отличительной особеностью данной модели являются skip-conenctions, соединяющие части декодера и энкодера. То есть для того чтобы передать на вход декодера тензор, мы конкатенируем симметричный выход с энкодера и выход предыдущего слоя декодера.
- Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. "U-Net: Convolutional networks for biomedical image segmentation." International Conference on Medical image computing and computer-assisted intervention. Springer, Cham, 2015.
В оригинальной статье авторы не использовали padding внутри модели (это видно по тому, что размеры карты признаков уменьшаются на 2 каждый раз при движении от слоя к слою). При этом размеры входных изображений авторы единоразово увеличили при помощи mirror padding.
В этом домашнем задании вы можете применить альтернативный подход - сохранять размеры карт признаков при помощью padding = 1 во внутренних слоях.
import torch.nn.functional as F
import torch.nn as nn
Для реализации UNet вы можете написать классы блоков энкодера и декодера отдельно, как мы сделали при реализации SegNet.
class UNetEncoder(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
)
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
def forward(self, x):
x_conv = self.conv(x)
x_down = self.pool(x_conv)
return x_conv, x_down
class UNetDecoder(nn.Module):
def __init__(self, in_channels, out_channels, padding=1):
super().__init__()
self.up_conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
self.conv = nn.Sequential(
nn.Conv2d(out_channels * 2, out_channels, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
)
def forward(self, x, skip):
x = self.up_conv(x)
x = torch.cat([skip, x], dim=1)
return self.conv(x)
class UNet(nn.Module):
def __init__(self, n_class=1):
super().__init__()
self.e0 = UNetEncoder(3, 64)
self.e1 = UNetEncoder(64, 128)
self.e2 = UNetEncoder(128, 256)
self.e3 = UNetEncoder(256, 512)
self.e4 = UNetEncoder(512, 1024)
self.d0 = UNetDecoder(1024, 512)
self.d1 = UNetDecoder(512, 256)
self.d2 = UNetDecoder(256, 128)
self.d3 = UNetDecoder(128, 64)
self.final_conv = nn.Conv2d(64, n_class, kernel_size=1)
def forward(self, x):
x0_conv, x0_down = self.e0(x)
x1_conv, x1_down = self.e1(x0_down)
x2_conv, x2_down = self.e2(x1_down)
x3_conv, x3_down = self.e3(x2_down)
x4_conv, x4_down = self.e4(x3_down)
x = self.d0(x4_conv, x3_conv)
x = self.d1(x, x2_conv)
x = self.d2(x, x1_conv)
x = self.d3(x, x0_conv)
output = self.final_conv(x)
return output
Обучите UNet¶
Задание: обучите UNet на всех трех лоссах: BCE, Dice, Focal и сравните результаты с SegNet:
- Какая модель дает лучшие значения по метрике?
- Какая модель дает лучшие значения по лоссам?
- Какая модель обучается быстрее?
- Сравните визуально результаты SegNet и UNet.
Напишите развернутый ответ на вопросы.
- Обучим U-Net на Focal Loss
unet_model = UNet().to(DEVICE)
criterion = focal_loss_wrapper
optimizer = torch.optim.Adam(unet_model.parameters(), lr = 1e-3)
unet_model, unet_score_focal = train_func(unet_model, 50, loaders, optimizer, criterion)
- Обучим U-Net на Dice Loss
def dice_loss_wrapper(logits, labels):
loss = dice_loss(logits, labels)
return loss.mean()
unet_model_dice = UNet().to(DEVICE)
criterion = dice_loss_wrapper
optimizer = torch.optim.Adam(unet_model_dice.parameters(), lr = 3e-4)
unet_model_dice, unet_score_dice = train_func(unet_model_dice, 50, loaders, optimizer, criterion)
- Обучим U-Net на BCE Loss
unet_model_bce = UNet().to(DEVICE)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(unet_model_bce.parameters(), lr = 1e-3)
unet_model_bce, unet_score_bce = train_func(unet_model_bce, 50, loaders, optimizer, criterion)
test_unet_dice = test(unet_model_dice, test_dataloader)
test_unet_focal = test(unet_model, test_dataloader)
test_unet_bce = test(unet_model_bce, test_dataloader)
print("Модель с Dice Loss имеет Iou:", test_unet_dice.cpu().item())
print("Модель с Focal Loss имеет Iou:", test_unet_focal.cpu().item())
print("Модель с BCE Loss имеет Iou:", test_unet_bce.cpu().item())
Модель с Dice Loss имеет Iou: 0.3579278588294983 Модель с Focal Loss имеет Iou: 0.025996873155236244 Модель с BCE Loss имеет Iou: 0.5349061489105225
def to_float_list(x):
return [float(v) for v in x]
unet_score_bce['val_iou'] = to_float_list(unet_score_bce['val_iou'])
unet_score_dice['val_iou'] = to_float_list(unet_score_dice['val_iou'])
unet_score_focal['val_iou'] = to_float_list(unet_score_focal['val_iou'])
plt.plot(range(1, 51), unet_score_bce['val_iou'], label='score baseline')
plt.plot(range(1, 51), unet_score_dice['val_iou'], label='score dice')
plt.plot(range(1, 51), unet_score_focal['val_iou'], label='score focal')
plt.title('IoU на валидационной выборке для модели U-Net')
plt.xlabel('epoch')
plt.ylabel('IoU score')
plt.legend()
<matplotlib.legend.Legend at 0x7f73d739ae10>